import logging
import gym
from typing import Dict, Tuple

import ray
from ray.rllib.agents.ddpg.ddpg_tf_policy import (
    build_ddpg_models,
    get_distribution_inputs_and_class,
    validate_spaces,
)
from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, PRIO_WEIGHTS
from ray.rllib.agents.sac.sac_torch_policy import TargetNetworkMixin
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDeterministic, TorchDirichlet
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.torch_utils import (
    apply_grad_clipping,
    concat_multi_gpu_td_errors,
    huber_loss,
    l2_loss,
)
from ray.rllib.utils.typing import (
    TrainerConfigDict,
    TensorType,
    LocalOptimizer,
    GradInfoDict,
)

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)


def build_ddpg_models_and_action_dist(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> Tuple[ModelV2, ActionDistribution]:
    """Build the DDPG model and pick the matching torch action distribution.

    Simplex action spaces get a Dirichlet distribution; every other action
    space gets a deterministic distribution.

    Returns:
        A ``(model, action_distribution_class)`` tuple.
    """
    ddpg_model = build_ddpg_models(policy, obs_space, action_space, config)
    dist_class = (
        TorchDirichlet if isinstance(action_space, Simplex) else TorchDeterministic
    )
    return ddpg_model, dist_class


def _smooth_target_actions(policy: Policy, policy_tp1: TensorType) -> TensorType:
    """Add clipped Gaussian noise to target actions and clip to action bounds.

    The noise std is ``config["target_noise"]``. The noise is clipped to
    ``[-target_noise_clip, target_noise_clip]``. The noisy actions are then
    clipped to ``[action_space.low, action_space.high]``.
    """
    target_noise_clip = policy.config["target_noise_clip"]
    # Sample directly with policy_tp1's device/dtype instead of allocating a
    # CPU tensor and copying it to the device on every training step.
    clipped_normal_sample = torch.clamp(
        torch.randn_like(policy_tp1) * policy.config["target_noise"],
        -target_noise_clip,
        target_noise_clip,
    )
    low = torch.tensor(
        policy.action_space.low, dtype=torch.float32, device=policy_tp1.device
    )
    high = torch.tensor(
        policy.action_space.high, dtype=torch.float32, device=policy_tp1.device
    )
    return torch.min(torch.max(policy_tp1 + clipped_normal_sample, low), high)


def ddpg_actor_critic_loss(
    policy: Policy, model: ModelV2, _, train_batch: SampleBatch
) -> Tuple[TensorType, TensorType]:
    """Compute the DDPG actor and critic losses for one training batch.

    Supports the optional TD3-style features enabled via config:
    ``twin_q`` (clipped double-Q target) and ``smooth_target_policy``
    (noisy target actions). It also supports Huber or squared TD loss,
    optional l2 regularization and the model's ``custom_loss`` hook.

    Side effect: stores ``q_t``, ``actor_loss``, ``critic_loss`` and
    ``td_error`` in ``model.tower_stats``. Stats and TD-error consumers
    read them from there.

    Args:
        policy: The policy owning ``config``, ``action_space`` and
            ``target_models``.
        model: The (possibly per-tower) online model.
        _: Unused action-distribution-class slot.
        train_batch: Batch with obs, actions, rewards, dones, next obs and
            prioritized-replay weights.

    Returns:
        ``(actor_loss, critic_loss)``, in the same order as the optimizers
        returned by ``make_ddpg_optimizers``.
    """
    config = policy.config
    target_model = policy.target_models[model]

    twin_q = config["twin_q"]
    gamma = config["gamma"]
    n_step = config["n_step"]
    use_huber = config["use_huber"]
    huber_threshold = config["huber_threshold"]
    l2_reg = config["l2_reg"]

    input_dict = SampleBatch(obs=train_batch[SampleBatch.CUR_OBS], _is_training=True)
    input_dict_next = SampleBatch(
        obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True
    )

    model_out_t, _ = model(input_dict, [], None)
    # NOTE(review): the output of this forward pass is not used below. The call
    # is kept so that any side effects of running the model on next-obs
    # (e.g. normalization statistics) are unchanged.
    model_out_tp1, _ = model(input_dict_next, [], None)
    target_model_out_tp1, _ = target_model(input_dict_next, [], None)

    # Deterministic policy outputs: online net at t, target net at t+1.
    policy_t = model.get_policy_output(model_out_t)
    policy_tp1 = target_model.get_policy_output(target_model_out_tp1)

    if config["smooth_target_policy"]:
        policy_tp1_smoothed = _smooth_target_actions(policy, policy_tp1)
    else:
        # No smoothing, just use deterministic actions.
        policy_tp1_smoothed = policy_tp1

    actions = train_batch[SampleBatch.ACTIONS]
    # Q-values for the actions actually taken in the batch.
    q_t = model.get_q_values(model_out_t, actions)
    # Q-values of the current (noise-free) policy; the actor maximizes these.
    q_t_det_policy = model.get_q_values(model_out_t, policy_t)
    actor_loss = -torch.mean(q_t_det_policy)

    if twin_q:
        twin_q_t = model.get_twin_q_values(model_out_t, actions)

    # Target Q-net(s) evaluated at the (possibly smoothed) target actions.
    q_tp1 = target_model.get_q_values(target_model_out_tp1, policy_tp1_smoothed)
    if twin_q:
        twin_q_tp1 = target_model.get_twin_q_values(
            target_model_out_tp1, policy_tp1_smoothed
        )

    q_t_selected = torch.squeeze(q_t, dim=-1)
    if twin_q:
        twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
        # Clipped double-Q: bootstrap from the smaller target estimate.
        q_tp1 = torch.min(q_tp1, twin_q_tp1)

    q_tp1_best = torch.squeeze(q_tp1, dim=-1)
    q_tp1_best_masked = (1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # RHS of the (n-step) Bellman equation. It is detached so no gradient
    # flows through the target.
    q_t_selected_target = (
        train_batch[SampleBatch.REWARDS] + gamma ** n_step * q_tp1_best_masked
    ).detach()

    def _elementwise_error(err: TensorType) -> TensorType:
        # Huber (clipped) or squared error, per config.
        if use_huber:
            return huber_loss(err, huber_threshold)
        return 0.5 * torch.pow(err, 2.0)

    td_error = q_t_selected - q_t_selected_target
    errors = _elementwise_error(td_error)
    if twin_q:
        twin_td_error = twin_q_t_selected - q_t_selected_target
        errors = errors + _elementwise_error(twin_td_error)

    critic_loss = torch.mean(train_batch[PRIO_WEIGHTS] * errors)

    # Add l2-regularization (non-bias weights only) if required. The add is
    # out-of-place to avoid mutating autograd graph tensors in place.
    if l2_reg is not None:
        for name, var in model.policy_variables(as_dict=True).items():
            if "bias" not in name:
                actor_loss = actor_loss + l2_reg * l2_loss(var)
        for name, var in model.q_variables(as_dict=True).items():
            if "bias" not in name:
                critic_loss = critic_loss + l2_reg * l2_loss(var)

    # Model self-supervised losses.
    if config["use_state_preprocessor"]:
        # Expand input_dict in case custom_loss' need them.
        input_dict[SampleBatch.ACTIONS] = train_batch[SampleBatch.ACTIONS]
        input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS]
        input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES]
        input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS]
        [actor_loss, critic_loss] = model.custom_loss(
            [actor_loss, critic_loss], input_dict
        )

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_t"] = q_t
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    # Per-item TD errors; concatenated across towers for priority updates.
    model.tower_stats["td_error"] = td_error

    # Two loss terms, matching the two optimizers.
    return actor_loss, critic_loss


def make_ddpg_optimizers(
    policy: Policy, config: TrainerConfigDict
) -> Tuple[LocalOptimizer, LocalOptimizer]:
    """Create separate Adam optimizers for the actor and the critic.

    The optimizers are also stored on the policy as ``_actor_optimizer`` and
    ``_critic_optimizer`` so that ``apply_gradients_fn`` can step them
    independently.

    Args:
        policy: Policy whose model exposes ``policy_variables()`` and
            ``q_variables()``.
        config: Trainer config providing ``actor_lr`` and ``critic_lr``.

    Returns:
        ``(actor_optimizer, critic_optimizer)``, in the same order as the
        loss terms returned by the loss function.
    """
    # Epsilon chosen to match tf.keras.optimizers.Adam's epsilon default.
    adam_eps = 1e-7

    policy._actor_optimizer = torch.optim.Adam(
        params=policy.model.policy_variables(), lr=config["actor_lr"], eps=adam_eps
    )
    policy._critic_optimizer = torch.optim.Adam(
        params=policy.model.q_variables(), lr=config["critic_lr"], eps=adam_eps
    )

    return policy._actor_optimizer, policy._critic_optimizer


def apply_gradients_fn(policy: Policy, gradients: GradInfoDict) -> None:
    """Step the critic on every call and the actor every `policy_delay` calls.

    The actor step happens only when ``policy.global_step`` is a multiple of
    ``config["policy_delay"]``. The step counter is advanced afterwards.
    ``gradients`` is not read here; the optimizers step on whatever gradients
    are already stored on their parameters.
    """
    actor_update_due = policy.global_step % policy.config["policy_delay"] == 0
    if actor_update_due:
        policy._actor_optimizer.step()

    policy._critic_optimizer.step()

    # Count this update.
    policy.global_step += 1


def build_ddpg_stats(policy: Policy, batch: SampleBatch) -> Dict[str, TensorType]:
    """Aggregate per-tower loss and Q-value statistics.

    Reads the values the loss function stored in each tower's
    ``tower_stats`` (via ``policy.get_tower_stats``).

    Returns:
        Dict with the mean actor/critic loss across towers and the
        mean/max/min of all Q-values ``q_t``.
    """
    # Concatenate (rather than stack) the per-tower Q batches. Towers that
    # received differently sized batch slices can still be combined, and the
    # result is identical to stacking when sizes match.
    q_t = torch.cat(policy.get_tower_stats("q_t"), dim=0)
    # The losses are scalars per tower, so stacking them is always valid.
    stats = {
        "actor_loss": torch.mean(torch.stack(policy.get_tower_stats("actor_loss"))),
        "critic_loss": torch.mean(torch.stack(policy.get_tower_stats("critic_loss"))),
        "mean_q": torch.mean(q_t),
        "max_q": torch.max(q_t),
        "min_q": torch.min(q_t),
    }
    return stats


def before_init_fn(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> None:
    """Initialize the policy's update counter before the policy is built.

    ``policy.global_step`` starts at zero. It counts calls to
    ``apply_gradients_fn`` and gates the delayed actor updates.
    """
    initial_step = 0
    policy.global_step = initial_step


class ComputeTDErrorMixin:
    """Give the policy a ``compute_td_error`` callable.

    The callable packs its arguments into a batch, runs ``loss_fn`` on it
    (discarding the returned losses) and returns the per-item TD errors that
    the loss function stored in ``self.model.tower_stats["td_error"]``.
    """

    def __init__(self, loss_fn):
        policy = self

        def compute_td_error(
            obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights
        ):
            raw_batch = SampleBatch(
                {
                    SampleBatch.CUR_OBS: obs_t,
                    SampleBatch.ACTIONS: act_t,
                    SampleBatch.REWARDS: rew_t,
                    SampleBatch.NEXT_OBS: obs_tp1,
                    SampleBatch.DONES: done_mask,
                    PRIO_WEIGHTS: importance_weights,
                }
            )
            tensor_batch = policy._lazy_tensor_dict(raw_batch)
            # Run the loss only for its side effect of recording TD errors
            # (one value per batch item, used to update replay priorities).
            loss_fn(policy, policy.model, None, tensor_batch)
            return policy.model.tower_stats["td_error"]

        self.compute_td_error = compute_td_error


def setup_late_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> None:
    """Run the mixin initializers that must happen before loss initialization.

    This attaches ``compute_td_error`` (bound to the DDPG loss) and then
    initializes the target-network mixin.
    """
    ComputeTDErrorMixin.__init__(policy, loss_fn=ddpg_actor_critic_loss)
    TargetNetworkMixin.__init__(policy)


# Torch DDPG policy class assembled from the hooks defined in this module.
DDPGTorchPolicy = build_policy_class(
    name="DDPGTorchPolicy",
    framework="torch",
    get_default_config=lambda: ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG,
    # Space validation, model and action-distribution construction.
    validate_spaces=validate_spaces,
    make_model_and_action_dist=build_ddpg_models_and_action_dist,
    action_distribution_fn=get_distribution_inputs_and_class,
    # Lifecycle hooks and mixins.
    before_init=before_init_fn,
    before_loss_init=setup_late_mixins,
    mixins=[
        TargetNetworkMixin,
        ComputeTDErrorMixin,
    ],
    # Trajectory postprocessing.
    postprocess_fn=postprocess_nstep_and_prio,
    # Loss, optimizers and gradient handling.
    loss_fn=ddpg_actor_critic_loss,
    optimizer_fn=make_ddpg_optimizers,
    extra_grad_process_fn=apply_grad_clipping,
    apply_gradients_fn=apply_gradients_fn,
    # Reporting.
    stats_fn=build_ddpg_stats,
    extra_learn_fetches_fn=concat_multi_gpu_td_errors,
)
